package cassandra

import (
	"strings"

	"github.com/antlr4-go/antlr/v4"
	"github.com/bytebase/parser/cql"

	storepb "github.com/bytebase/bytebase/backend/generated-go/store"
	"github.com/bytebase/bytebase/backend/plugin/parser/base"
	"github.com/bytebase/bytebase/backend/utils"
)

// init registers the Cassandra parser entry points with the base parser registry
// for storepb.Engine_CASSANDRA: parseCassandraForRegistry as the ParseFunc and
// parseCassandraStatements as the ParseStatementsFunc.
func init() {
	base.RegisterParseFunc(storepb.Engine_CASSANDRA, parseCassandraForRegistry)
	base.RegisterParseStatementsFunc(storepb.Engine_CASSANDRA, parseCassandraStatements)
}

// parseCassandraForRegistry is the ParseFunc registered for Cassandra.
// It parses every non-empty statement in the input via ParseCassandraSQL and
// wraps each result as a *base.ANTLRAST. Any split or parse error is returned
// unchanged with a nil slice.
func parseCassandraForRegistry(statement string) ([]base.AST, error) {
	results, parseErr := ParseCassandraSQL(statement)
	if parseErr != nil {
		return nil, parseErr
	}
	asts := toAST(results)
	return asts, nil
}

// toAST wraps each *base.ParseResult in a *base.ANTLRAST, carrying over the
// parse tree and token stream and recording BaseLine+1 as the start line.
// A nil or empty input yields a nil slice.
func toAST(results []*base.ParseResult) []base.AST {
	var out []base.AST
	for i := range results {
		res := results[i]
		node := &base.ANTLRAST{
			StartPosition: &storepb.Position{Line: int32(res.BaseLine) + 1},
			Tree:          res.Tree,
			Tokens:        res.Tokens,
		}
		out = append(out, node)
	}
	return out
}

// parseCassandraStatements is the ParseStatementsFunc for Cassandra.
//
// It splits the input once with SplitSQL. It returns one base.Statement per split
// statement, with its text, emptiness flag, and start/end positions and byte offsets.
// Every non-empty statement is also parsed with parseSingleCassandraSQL, and the
// resulting ANTLR tree is attached as its AST. Empty statements carry no AST.
//
// The first split or parse error aborts the whole call and is returned with a nil
// slice.
func parseCassandraStatements(statement string) ([]base.Statement, error) {
	singleSQLs, err := SplitSQL(statement)
	if err != nil {
		return nil, err
	}

	var statements []base.Statement
	for _, sql := range singleSQLs {
		stmt := base.Statement{
			Text:            sql.Text,
			Empty:           sql.Empty,
			StartPosition:   sql.Start,
			EndPosition:     sql.End,
			ByteOffsetStart: sql.ByteOffsetStart,
			ByteOffsetEnd:   sql.ByteOffsetEnd,
		}
		if !sql.Empty {
			// Parse this statement directly instead of re-splitting the whole
			// input and pairing results by index. Each AST then always belongs
			// to the statement it was parsed from.
			parseResult, parseErr := parseSingleCassandraSQL(sql.Text, sql.BaseLine)
			if parseErr != nil {
				return nil, parseErr
			}
			stmt.AST = &base.ANTLRAST{
				StartPosition: &storepb.Position{Line: int32(parseResult.BaseLine) + 1},
				Tree:          parseResult.Tree,
				Tokens:        parseResult.Tokens,
			}
		}
		statements = append(statements, stmt)
	}

	return statements, nil
}

// ParseCassandraSQL parses the given CQL text with ANTLR4.
//
// The text is first split into individual statements with SplitSQL, and each
// non-empty statement is parsed separately. It returns one *base.ParseResult
// (parse tree and token stream) per non-empty statement, in source order. On
// failure it returns the first split or parse error and a nil slice.
func ParseCassandraSQL(statement string) ([]*base.ParseResult, error) {
	singleSQLs, splitErr := SplitSQL(statement)
	if splitErr != nil {
		return nil, splitErr
	}

	var parsed []*base.ParseResult
	for i := range singleSQLs {
		single := singleSQLs[i]
		if !single.Empty {
			r, parseErr := parseSingleCassandraSQL(single.Text, single.BaseLine)
			if parseErr != nil {
				return nil, parseErr
			}
			parsed = append(parsed, r)
		}
	}
	return parsed, nil
}

// parseSingleCassandraSQL parses one CQL statement with the generated ANTLR
// lexer and parser.
//
// Trailing whitespace and semicolons are trimmed and "\n;" is appended, so the
// parser always sees exactly one terminating semicolon on its own line.
// baseLine is stored in the result, and baseLine+1 is handed to the error
// listeners as their start line.
//
// A lexer error is reported in preference to a parser error. On any error the
// result is nil.
func parseSingleCassandraSQL(statement string, baseLine int) (*base.ParseResult, error) {
	normalized := strings.TrimRightFunc(statement, utils.IsSpaceOrSemicolon) + "\n;"

	lexer := cql.NewCqlLexer(antlr.NewInputStream(normalized))
	tokens := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel)
	parser := cql.NewCqlParser(tokens)

	// Both listeners share one start position. Presumably ParseErrorListener
	// uses it to offset reported positions to where this statement begins in
	// the original input (TODO confirm).
	start := &storepb.Position{Line: int32(baseLine) + 1}
	lexerErrs := &base.ParseErrorListener{Statement: normalized, StartPosition: start}
	parserErrs := &base.ParseErrorListener{Statement: normalized, StartPosition: start}

	// Swap ANTLR's default listeners for ones that capture the error.
	lexer.RemoveErrorListeners()
	lexer.AddErrorListener(lexerErrs)
	parser.RemoveErrorListeners()
	parser.AddErrorListener(parserErrs)

	parser.BuildParseTrees = true
	root := parser.Root()

	switch {
	case lexerErrs.Err != nil:
		return nil, lexerErrs.Err
	case parserErrs.Err != nil:
		return nil, parserErrs.Err
	}

	return &base.ParseResult{
		Tree:     root,
		Tokens:   tokens,
		BaseLine: baseLine,
	}, nil
}
